/**
 * 
 */
package edu.berkeley.nlp.PCFGLA;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;

/**
 * @author petrov
 *
 */
public class HierarchicalBinaryRule extends BinaryRule {

	private static final long serialVersionUID = 1L;

	public HierarchicalBinaryRule(HierarchicalBinaryRule b) {
		super(b);
		this.scoreHierarchy = new ArrayList<double[][][]>();
		for (double[][][] scores : b.scoreHierarchy){
			this.scoreHierarchy.add(ArrayUtil.clone(scores));
		}
		this.lastLevel = b.lastLevel;
		this.scores = null;
	}

	// assume for now that the rule being passed in is unsplit
	public HierarchicalBinaryRule(BinaryRule b) {
		super(b);
		this.scoreHierarchy = new ArrayList<double[][][]>();
		double[][][] scoreThisLevel = new double[1][1][1];
		scoreThisLevel[0][0][0] = Math.log(b.scores[0][0][0]);
		scoreHierarchy.add(scoreThisLevel);
		this.lastLevel = 0;
		this.scores = null;
	}

	/*
	 * new stuff below
	 */ 
	
  /**
	 * before: scores[leftSubState][rightSubState][parentSubState] gives score for this rule
	 * now: have a hierarchy of refinements 
	 */
	
	List<double[][][]> scoreHierarchy;
	public int lastLevel = -1;
	
	public void explicitlyComputeScores(int finalLevel, short[] newNumSubStates){
		int newMaxStates = (int)Math.pow(2,finalLevel+1);
		int newPStates = Math.min(newMaxStates, newNumSubStates[this.parentState]);
		int newLStates = Math.min(newMaxStates, newNumSubStates[this.leftChildState]);
		int newRStates = Math.min(newMaxStates, newNumSubStates[this.rightChildState]);

		this.scores = new double[newLStates][newRStates][newPStates];
		for (int level=0; level<=lastLevel; level++){
			double[][][] scoresThisLevel = scoreHierarchy.get(level); 
			if (scoresThisLevel == null) continue;
			int divisorL = newLStates / scoresThisLevel.length;
			int divisorR = newRStates / scoresThisLevel[0].length;
			int divisorP = newPStates / scoresThisLevel[0][0].length;
			for (int lChild=0; lChild<newLStates; lChild++){
				for (int rChild=0; rChild<newRStates; rChild++){
					for (int parent=0; parent<newPStates; parent++){
						this.scores[lChild][rChild][parent] += scoresThisLevel[lChild/divisorL][rChild/divisorR][parent/divisorP];
					}
				}
			}
		}
		for (int lChild=0; lChild<newLStates; lChild++){
			for (int rChild=0; rChild<newRStates; rChild++){
				for (int parent=0; parent<newPStates; parent++){
					this.scores[lChild][rChild][parent] = Math.exp(scores[lChild][rChild][parent]);
				}
			}
		}
	}

	public double[][][] getLastLevel(){
		return this.scoreHierarchy.get(lastLevel);
	}
	
	public HierarchicalBinaryRule splitRule(short[] numSubStates, short[] newNumSubStates, Random random, double randomness, boolean doNotNormalize, int mode) {
		// when splitting on parent, never split on ROOT, but otherwise split everything
		if (mode!=2) throw new Error("Can't split hiereachical rule in this mode!");
		
		int newMaxStates = (int)Math.pow(2,lastLevel+1);
		int newPStates = Math.min(newMaxStates, newNumSubStates[this.parentState]);
		int newLStates = Math.min(newMaxStates, newNumSubStates[this.leftChildState]);
		int newRStates = Math.min(newMaxStates, newNumSubStates[this.rightChildState]);
			
		double[][][] newScores = new double[newLStates][newRStates][newPStates];
		for (int lChild=0; lChild<newLStates; lChild++){
			for (int rChild=0; rChild<newRStates; rChild++){
				for (int parent=0; parent<newPStates; parent++){
					newScores[lChild][rChild][parent] = random.nextDouble()/100.0;
				}
			}
		}
		HierarchicalBinaryRule newRule = new HierarchicalBinaryRule(this);
		newRule.scoreHierarchy.add(newScores);
		newRule.lastLevel++;
		return newRule;
	}

	public int mergeRule() {
		double[][][] scoresFinalLevel = scoreHierarchy.get(lastLevel);
		boolean allZero = true;
		for (int lChild=0; lChild<scoresFinalLevel.length; lChild++){
			for (int rChild=0; rChild<scoresFinalLevel[0].length; rChild++){
				for (int parent=0; parent<scoresFinalLevel[0][0].length; parent++){
					allZero = allZero && (scoresFinalLevel[lChild][rChild][parent] == 0.0);
				}
			}
		}
		if (allZero) {
			scoresFinalLevel = null;
			scoreHierarchy.remove(lastLevel);
			lastLevel--;
			return 1;
		}
		return 0;
	}
  
  public String toString() {
    Numberer n = Numberer.getGlobalNumberer("tags");
    String lState = (String)n.object(leftChildState);
    String rState = (String)n.object(rightChildState);
    String pState = (String)n.object(parentState);
    StringBuilder sb = new StringBuilder();
    if (scores==null) return pState+" -> "+lState+" "+rState+"\n";
    //sb.append(pState+ " -> "+lState+ " "+rState+ "\n");
    sb.append(pState+" -> "+lState+" "+rState+"\n");
    sb.append(ArrayUtil.toString(scores)+"\n");
    for (double[][][] s : scoreHierarchy){
    	sb.append(ArrayUtil.toString(s)+"\n");
    }
    sb.append("\n");
    return sb.toString();
  }

  public int countNonZeroFeatures(){
  	int total = 0;
		for (int level=0; level<=lastLevel; level++){
			double[][][] scoresThisLevel = scoreHierarchy.get(level); 
			if (scoresThisLevel == null) continue;
			for (int lChild=0; lChild<scoresThisLevel.length; lChild++){
				for (int rChild=0; rChild<scoresThisLevel.length; rChild++){
					for (int parent=0; parent<scoresThisLevel.length; parent++){
						if (scoresThisLevel[lChild][rChild][parent]!=0) total++;
					}
				}
			}
		}
		return total;
  }
  
}
